{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-30T14:20:58.273024Z",
     "start_time": "2019-11-30T14:20:58.267998Z"
    }
   },
   "outputs": [],
   "source": [
    "import foolbox\n",
    "import numpy as np\n",
    "import torchvision.models as models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-30T14:20:59.311356Z",
     "start_time": "2019-11-30T14:20:58.896522Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load a pretrained torchvision ResNet-18 and wrap it as a Foolbox model.\n",
    "# (Foolbox also supports Keras, TensorFlow (Graph and Eager), MXNet and many more.)\n",
    "model = models.resnet18(pretrained=True)\n",
    "model.eval()  # put the network in evaluation mode before attacking it\n",
    "\n",
    "# Per-channel mean/std normalization that Foolbox applies to inputs before the forward pass.\n",
    "# axis=-3 is the channel axis of the channels-first image batches used in this notebook.\n",
    "preprocessing = {\n",
    "    'mean': [0.485, 0.456, 0.406],\n",
    "    'std': [0.229, 0.224, 0.225],\n",
    "    'axis': -3,\n",
    "}\n",
    "\n",
    "# Inputs are expected in the [0, 1] range; the classifier has 1000 output classes.\n",
    "fmodel = foolbox.models.PyTorchModel(\n",
    "    model,\n",
    "    bounds=(0, 1),\n",
    "    num_classes=1000,\n",
    "    preprocessing=preprocessing,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-30T14:21:00.183939Z",
     "start_time": "2019-11-30T14:20:59.998374Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "攻击前准确率： 0.9375\n"
     ]
    }
   ],
   "source": [
    "# Fetch a small batch of sample ImageNet images (channels-first, values in [0, 1])\n",
    "# together with their labels, then report the model's accuracy before any attack.\n",
    "images, labels = foolbox.utils.samples(\n",
    "    dataset='imagenet',\n",
    "    batchsize=16,\n",
    "    data_format='channels_first',\n",
    "    bounds=(0, 1),\n",
    ")\n",
    "clean_predictions = fmodel.forward(images).argmax(axis=-1)\n",
    "print('攻击前准确率：', np.mean(clean_predictions == labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-30T14:21:01.549910Z",
     "start_time": "2019-11-30T14:21:00.860712Z"
    }
   },
   "outputs": [],
   "source": [
    "# Run the FGSM attack against the wrapped model on the whole batch.\n",
    "attack = foolbox.attacks.FGSM(fmodel)\n",
    "adversarials = attack(images, labels)\n",
    "# How to read the result for the i-th sample:\n",
    "#   - images[i] is already misclassified without a perturbation -> adversarials[i] is the same as images[i]\n",
    "#   - the attack fails to find an adversarial for images[i]     -> adversarials[i] is all np.nan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-30T14:21:01.618063Z",
     "start_time": "2019-11-30T14:21:01.590135Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "攻击后准确率： 0.0\n"
     ]
    }
   ],
   "source": [
    "# Foolbox guarantees that every returned adversarial is in fact adversarial.\n",
    "# Re-evaluate the model on the attacked batch to see the effect on accuracy.\n",
    "adversarial_predictions = fmodel.forward(adversarials).argmax(axis=-1)\n",
    "print('攻击后准确率：', np.mean(adversarial_predictions == labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-30T14:21:03.377650Z",
     "start_time": "2019-11-30T14:21:02.695209Z"
    }
   },
   "outputs": [],
   "source": [
    "# Passing `unpack=False` makes the attack return the actual `Adversarial` objects,\n",
    "# which record the adversarial class Foolbox observed for each sample.\n",
    "# NOTE: `adversarials` is rebound here - from now on it holds `Adversarial` objects,\n",
    "# not the perturbed image arrays produced by the earlier attack call.\n",
    "attack = foolbox.attacks.FGSM(fmodel, distance=foolbox.distances.Linf)\n",
    "adversarials = attack(images, labels, unpack=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-30T14:30:56.470493Z",
     "start_time": "2019-11-30T14:30:56.460484Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "原本标签： [243 559 438 990 949 853 609 609 915 455 541 630 741 471 129  99]\n",
      "攻击后标签： [242 694 711 937 927 706 479 511 672 539 463 636 497 870  89 138]\n",
      "准确度： 0.0\n"
     ]
    }
   ],
   "source": [
    "# Gather the class recorded on each `Adversarial` object and compare it with the true labels.\n",
    "adversarial_classes = np.asarray([adv.adversarial_class for adv in adversarials])\n",
    "print('原本标签：', labels)\n",
    "print('攻击后标签：', adversarial_classes)\n",
    "still_correct = adversarial_classes == labels\n",
    "print('准确度：', np.mean(still_correct))  # will always be 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-11-30T14:21:04.783932Z",
     "start_time": "2019-11-30T14:21:04.771915Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0e+00, 8.0e-04, 2.9e-03\n",
      "0 of 16 attacks failed\n",
      "1 of 16 inputs misclassified without perturbation\n"
     ]
    }
   ],
   "source": [
    "# Each `Adversarial` object also exposes a `distance` attribute. A distance of 0 means the\n",
    "# input was misclassified without perturbation; a distance of inf means the attack failed.\n",
    "distances = np.asarray([adv.distance.value for adv in adversarials])\n",
    "n_total = len(adversarials)\n",
    "n_failed = np.count_nonzero(distances == np.inf)\n",
    "n_unperturbed = np.count_nonzero(distances == 0)\n",
    "\n",
    "# Summary: smallest, median and largest perturbation size, followed by the two special cases.\n",
    "print('{:.1e}, {:.1e}, {:.1e}'.format(distances.min(), np.median(distances), distances.max()))\n",
    "print('{} of {} attacks failed'.format(n_failed, n_total))\n",
    "print('{} of {} inputs misclassified without perturbation'.format(n_unperturbed, n_total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
